"""

Streak plots code updated
:version: 11/05/2020

"""
from hexapole_old import HexVector, Assembly
from Verlet import verletFlyer, loadFinal, rewind

import numpy as np
import logging
import matplotlib.pyplot as plt
import matplotlib as mpl
import FigureSetup
import os
import os.path
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from mpl_toolkits.axes_grid1.inset_locator import mark_inset
from matplotlib.ticker import FormatStrFormatter
from mpl_toolkits.axes_grid1 import make_axes_locatable
from matplotlib import ticker

######################THINGS TO CHANGE#############################################
# targetVel, focal length, guide parameters, input/output folder name, name of npz file to be saved x2 (or 3)



# Set up logging and message detail level. Set the handler level to
# logging.INFO for a quieter output.
# NOTE: the root logger itself is left at Python's default WARNING level, so
# DEBUG/INFO records are dropped before they reach this handler; also call
# logger.setLevel(...) to actually see them.
logger = logging.getLogger()
ch = logging.StreamHandler()
ch.setLevel(logging.DEBUG)
# Each line reads "<logger name> - <LEVEL> - <message>".
ch.setFormatter(logging.Formatter('%(name)s - %(levelname)s - %(message)s'))
logger.addHandler(ch)


def Guide(HA1e2ypos, HA3e4ypos, HA4zshift, heightBlade1, heightBlade2):
    """Build the double-shift guide from four hexapole arrays.

    HA1e2ypos: y position shared by arrays 1 and 2.
    HA3e4ypos: y position shared by arrays 3 and 4.
    HA4zshift: extra z offset applied to array 4 only.
    heightBlade1, heightBlade2: accepted by the call signature but not used here.

    Array positions depend on the module-level ``focallength``: the second
    array of each pair sits 2*focallength after the first.

    Returns an Assembly of the arrays in the order 1, 2, 3, 4.
    """
    fieldFile = 'Bvec_ri3lh7B1415.h5'

    # z of the first array in each pair. The second pair begins
    # 24.0 + 5.0 + 5.0 later: 2f(MaxVel) + 2*(array half-widths) + arbitrary spacing.
    firstPairZ = 241.15
    secondPairZ = firstPairZ + 24.0 + 5.0 + 5.0
    pairSpacing = 2*focallength

    # (y, z) of each array, in assembly order.
    placements = [
        (HA1e2ypos, firstPairZ),
        (HA1e2ypos, firstPairZ + pairSpacing + 0.0),
        (HA3e4ypos, secondPairZ),
        (HA3e4ypos, secondPairZ + pairSpacing + HA4zshift),
    ]
    arrays = [HexVector(fieldFile, position=[0.0, y, z]) for y, z in placements]
    return Assembly(arrays)

# Parameters to vary

# ---- Data locations ----------------------------------------------------------
input_folder = r'C:\Users\brhgroup\OneDrive - Nexus365\Lab\Guide simulation\Article\Data for ORA\Simulation\Simulate signal\Input'
output_folder = r'C:\Users\brhgroup\OneDrive - Nexus365\Lab\Guide simulation\Article\Data for ORA\Simulation\Simulate signal'

# ---- Velocity selection ------------------------------------------------------
states = [0]            # not used for streak plots
targetVel = 0.350       # x10^3 m/s
tolerance = 0.010       # +/- half-width of the target window, same units as targetVel
maxvel, minvel = targetVel + tolerance, targetVel - tolerance

# ---- Beamline positions (z) --------------------------------------------------
focallength = 12.7
posCoil12 = 226.6
detectionPos = 350.00   # ion trap position (ion trap area detection)

# ---- Guide specifications ----------------------------------------------------
posBlade1, posBlade2 = 270.65, 310.0
HA1e2ypos, HA3e4ypos = 1.5, 0.5
HA4zshift = 0.0
heightBlade1, heightBlade2 = 1.1, -0.2

# ---- Hexapole-array dimensions (only used to draw the guide) -----------------
HAwidth = 7.0
HAradiusINT = 3.0
HAhalfwidth = HAwidth/2
HAradiusEXT = HAradiusINT + 4.0

########TestGuide.py##########################################################################################
##############################################################################################################
######## Generates streak images of the decelerated H-atom density through the guide (LFS only, divided into target, slower and faster velocity classes; HFS atoms are all lost) --- skimmed particles are DISCARDED

# z-range of the streak image: from the exit of coil 12 (231.6; its middle is
# at 226.6) to the ion-trap centre (350.0).
start, end = 231.6, 350.0

# Image resolution: nsteps columns along the flight direction and nheight rows
# spanning y in [-hmax, hmax].
nsteps, nheight = 200, 200
hmax = 10.0

# One independent, zero-filled image per velocity class.
flightImageTARGET, flightImageFASTER, flightImageSLOWER = (
    np.zeros((nsteps, nheight)) for _ in range(3))

flightBins = np.linspace(-hmax, hmax, nheight + 1)   # y-bin edges
steps = np.linspace(start, end, nsteps)             # z of each image column

# Generate the guide
hh = Guide(HA1e2ypos, HA3e4ypos, HA4zshift, heightBlade1, heightBlade2)

# A blade is applied at the flight step whose reference z-position lies within
# one image-column width past the blade's z-position.
bladeWindow = (end - start) / nsteps


def _passedBlade(pos, clearsBlade):
    """Return indices of particles that clear a blade and have not collided.

    pos: particle positions; column 1 is y.
    clearsBlade: boolean array, True where a particle's y clears the blade.
    """
    return np.intersect1d(np.where(clearsBlade)[0], hh.notCollided(pos))


def _streakImage(selectVz):
    """Fly a velocity class of LFS (state 0) atoms through the guide.

    selectVz: callable mapping the z-velocities (vel[:, 2]) to a boolean mask
        of the atoms to keep.

    Returns an (nsteps, nheight) array. Row j is the histogram (over
    flightBins) of the y positions of the surviving, non-collided atoms after
    flying to z = steps[j]. At each blade, atoms on the blocked side, and any
    that have collided, are discarded for the rest of the flight.
    """
    # Load the atoms and rewind them to the start plane (exit of coil 12).
    pos, vel, times = loadFinal(input_folder, states=[0])
    pos, vel, times = rewind(start, pos, vel, times)

    keep = selectVz(vel[:, 2])
    pos, vel, times = pos[keep, :], vel[keep, :], times[keep]

    image = np.zeros((nsteps, nheight))
    for j, step in enumerate(steps):
        print('{}/{}'.format(j, len(steps)))
        if pos.shape[0] == 0:
            # Every atom has been skimmed; the remaining columns stay empty.
            break
        pos, vel, times = verletFlyer(pos, vel, times, state=0, hexapole=hh,
                                      totalZ=step, dt=0.5, totalTime=100)

        # Reference z of the packet: the 75th percentile of particle z positions.
        zRef = np.quantile(pos[:, 2], 0.75)
        if posBlade1 <= zRef <= posBlade1 + bladeWindow:
            # Blade 1: only particles above heightBlade1 get through.
            ind = _passedBlade(pos, pos[:, 1] > heightBlade1)
            skimmed = True
        elif posBlade2 <= zRef <= posBlade2 + bladeWindow:
            # Blade 2: only particles below heightBlade2 get through.
            ind = _passedBlade(pos, pos[:, 1] < heightBlade2)
            skimmed = True
        else:
            # No blade here: just exclude collided particles from the image.
            ind = hh.notCollided(pos)
            skimmed = False

        image[j, :] = np.histogram(pos[ind, 1], flightBins)[0]
        if skimmed:
            # Drop blocked/collided atoms from pos, vel and times together so
            # the three arrays stay aligned for the next verletFlyer call.
            pos, vel, times = pos[ind, :], vel[ind, :], times[ind]
    return image


datafile = os.path.join(output_folder, 'GuideAnalysis1milOutput2Skimmed_1p5_0p5_1p1_m0p2_targetVel350pm10_final.npz')
if os.path.exists(datafile):
    # Reuse the cached images; the context manager closes the .npz handle.
    with np.load(datafile) as alldata:
        flightImageTARGET = alldata['flightImageTARGET']
        flightImageFASTER = alldata['flightImageFASTER']
        flightImageSLOWER = alldata['flightImageSLOWER']
else:
    # Velocity classes on v_z (bounds exclusive):
    # target: minvel < vz < maxvel, faster: vz > maxvel, slower: vz < minvel.
    flightImageTARGET = _streakImage(lambda vz: (vz > minvel) & (vz < maxvel))
    flightImageFASTER = _streakImage(lambda vz: vz > maxvel)
    flightImageSLOWER = _streakImage(lambda vz: vz < minvel)

    np.savez(datafile, flightImageTARGET=flightImageTARGET, flightImageFASTER=flightImageFASTER, flightImageSLOWER=flightImageSLOWER)
    
        
    
############Figure
f, (ax1, ax2, ax3) = FigureSetup.new_figure(nrows=3, ncols=1, sharex='all', sharey='all')
axes = (ax1, ax2, ax3)
images = (flightImageTARGET, flightImageFASTER, flightImageSLOWER)

print('%s %s %s' % (np.max(flightImageTARGET), np.max(flightImageFASTER), np.max(flightImageSLOWER)))
vmaxT = np.max(flightImageTARGET)
vmaxF = np.max(flightImageFASTER)
vmaxS = np.max(flightImageSLOWER)
# Percentage by which each colour scale's upper limit is lowered below the
# image peak. The brightest pixels saturate so fainter structure is visible.
# (Not named ``int``, to avoid shadowing the builtin.)
saturationPct = 70
vmaxClipped = [v - (v*saturationPct/100) for v in (vmaxT, vmaxF, vmaxS)]
im1, im2, im3 = [ax.imshow(img.T, origin='lower', cmap='viridis', vmax=vm,
                           extent=(start, end, -hmax, hmax), aspect=2)
                 for ax, img, vm in zip(axes, images, vmaxClipped)]

print('%s %s %s' % tuple(vmaxClipped))

# Draw the magnets on every panel.
magnetColour = [0.123463, 0.581687, 0.547445]  # alternative: u'#09B190' (Chris's colour)
for m in hh.magnetList:
    # Outline each half of the magnet as a rectangle in the magnet's frame,
    # then transform it into the lab frame for plotting.
    lower = np.array([[0.0, -m.ri, -HAhalfwidth], [0.0, -HAradiusEXT, -HAhalfwidth],
        [0, -HAradiusEXT, HAhalfwidth], [0, -m.ri, HAhalfwidth]])
    upper = np.array([[0.0,  m.ri, -HAhalfwidth], [0.0,  HAradiusEXT, -HAhalfwidth],
        [0,  HAradiusEXT, HAhalfwidth], [0,  m.ri, HAhalfwidth]])
    outlines = [m.toLab(lower), m.toLab(upper)]
    for ax in axes:
        for corners in outlines:
            ax.add_patch(mpl.patches.Polygon(np.vstack((corners[:, 2], corners[:, 1])).T,
                    fill=True, facecolor=magnetColour))

# Panel labels derived from the velocity window, so they track targetVel/tolerance.
panelLabels = (
    r'${v_z}$ = (%.0f $\pm$ %.0f) ms$^{-1}$' % (targetVel*1000, tolerance*1000),
    r'${v_z}$ $>$ %.0f ms$^{-1}$' % (maxvel*1000),
    r'${v_z}$ $<$ %.0f ms$^{-1}$' % (minvel*1000),
)
for ax, label in zip(axes, panelLabels):
    # Blades: blade 1 blocks from below up to heightBlade1; blade 2 blocks
    # from heightBlade2 upwards.
    ax.plot([posBlade1, posBlade1], [-hmax-1, heightBlade1], color='w', linestyle='-', linewidth=2.0)
    ax.plot([posBlade2, posBlade2], [heightBlade2, hmax+1], color='w', linestyle='-', linewidth=2.0)
    # Guide axis (y = 0).
    ax.plot([start, end], [0, 0], color='w', linestyle=':', linewidth=0.5)
    # Ion trap position.
    ax.plot([349.5, 349.5], [-1.3, 1.3], color='silver', linestyle='-', linewidth=2)
    ax.annotate(label, (312, -6), textcoords='data', size=10, color='w')
    ax.set_ylabel('y (mm)')

ax3.set_xlabel('z (mm)')
ax1.set_xlim([start, end])

# Thin colour bars to the right of each panel.
caxes = [make_axes_locatable(ax).append_axes("right", size="1%", pad=-2) for ax in axes]

# NOTE(review): these tick labels are hand-entered strings, not derived from
# the data; they must be updated by hand if the images or vmax change.
tickLabels = (['0', '0.017', '0.034'], ['0', '0.337', '0.674'], ['0', '0.022', '0.045'])
for im, cax, labels in zip((im1, im2, im3), caxes, tickLabels):
    cbar = plt.colorbar(im, cax=cax)
    # A fresh locator per colour bar: matplotlib locators must not be shared
    # between axes.
    cbar.locator = ticker.MaxNLocator(nbins=3)
    cbar.ax.yaxis.set_major_locator(ticker.AutoLocator())
    cbar.update_ticks()
    cbar.ax.set_yticklabels(labels)

# y is shared, so this sets all three panels. Intended for a figure of about
# 20 x 25 cm (or a similar ratio).
ax1.set_ylim(-8, 8)
f.subplots_adjust(hspace=-0.5, wspace=0.05)
plt.savefig("Final_1p5_0p5_1p1_m0p2_target350pm10.pdf", format="pdf")
plt.show()

